# Runs ResNet in inference mode on the fake_imagenet dataset
#
# Usage:
#   sky launch -c infer resnet_inference_app.yaml
#   sky down infer

# Task name; the cluster name is chosen separately via `-c` at launch time
# (see the usage above).
name: resnet-inference

# Hardware request: one V100 accelerator.
resources:
  accelerators: {V100: 1}

# Expose the S3 bucket at /tmp/resnet-model-dir; the `run` section passes this
# path to the script as --model_dir.
# NOTE(review): presumably the bucket holds the trained checkpoints - confirm.
file_mounts:
  /tmp/resnet-model-dir:
    source: s3://mluo-resnet-model-dir


setup: |
  # Fetch the ResNet sources. An existing checkout is reused so that setup can
  # be re-run on the same machine without the clone failing.
  if [ ! -d tpu/.git ]; then
    git clone https://github.com/concretevitamin/tpu || exit 1
  fi
  # Stop immediately if the checkout is unusable: every later step assumes it
  # is running inside tpu/ on the gpu_train branch.
  cd tpu || exit 1
  git checkout gpu_train || exit 1

  # Load conda's shell hook so that `conda activate` works in this script.
  . "$(conda info --base)/etc/profile.d/conda.sh"
  # NOTE(review): this runs before the resnet env is activated, so it upgrades
  # the pip of whichever environment is active at this point - confirm intent.
  pip install --upgrade pip

  # Build the resnet env only when it cannot be activated (first setup).
  if conda activate resnet; then
    echo "conda env exists"
  else
    # Chain the steps so that a failed install fails setup instead of silently
    # leaving a half-built env behind.
    conda create -n resnet python=3.7 -y &&
      conda activate resnet &&
      conda install cudatoolkit=11.0 -y &&
      pip install tensorflow==2.4.0 pyyaml &&
      pip install protobuf==3.20 &&
      cd models &&
      pip install -e . || {
        # The existence check above only tests activation, so a partial env
        # would be treated as complete on the next setup run.
        echo "resnet env setup failed; run 'conda env remove -n resnet' before retrying" >&2
        exit 1
      }
  fi

run: |
  # Work from the checkout created by setup; abort rather than run python from
  # the wrong directory or environment.
  cd tpu || exit 1
  . "$(conda info --base)/etc/profile.d/conda.sh"
  conda activate resnet || exit 1

  # Point XLA at the CUDA installation directory.
  export XLA_FLAGS='--xla_gpu_cuda_data_dir=/usr/local/cuda/'
  # Inference on the fake_imagenet dataset, loading the model from the
  # directory mounted via file_mounts.
  python -u models/official/resnet/resnet_main.py --use_tpu=False \
    --mode=infer --data_dir=gs://cloud-tpu-test-datasets/fake_imagenet \
    --model_dir=/tmp/resnet-model-dir --amp --xla --loss_scale=128 \
    --infer_batch_size=8 --infer_steps=10000
